/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

using namespace ::testing;
using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using std::optional;
using torch::executor::testing::TensorFactory;

/// Test fixture for `torch::executor::aten::native_group_norm_outf`.
///
/// Provides a thin wrapper around the operator, a table-driven test-case
/// description, and runners for both the success path (outputs compared
/// against expected values) and the failure path (the call is expected to be
/// reported as a kernel failure).
class OpNativeGroupNormTest : public OperatorTest {
 protected:
  /// Forwards every argument, together with the fixture's `context_`, to
  /// `torch::executor::aten::native_group_norm_outf` and returns its result
  /// unchanged.
  ::std::tuple<Tensor&, Tensor&, Tensor&> op_native_group_norm_out(
      const Tensor& input,
      const optional<Tensor>& weight,
      const optional<Tensor>& bias,
      int64_t N,
      int64_t C,
      int64_t HxW,
      int64_t group,
      double eps,
      Tensor& out0,
      Tensor& out1,
      Tensor& out2) {
    return torch::executor::aten::native_group_norm_outf(
        context_, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2);
  }

  /// One table entry: operator inputs plus the outputs they should produce.
  ///
  /// This is an aggregate initialized positionally by the test tables below,
  /// so the order of the fields must not change.
  template <ScalarType DTYPE>
  struct NativeGroupNormTestCase {
    using ctype = typename TensorFactory<DTYPE>::ctype;

    // Size vector for the input/output
    const std::vector<int32_t> sizes;
    // Data for the input tensor; must agree with `sizes`.
    const std::vector<ctype> input_data;
    // Affine transform weight; used to build a tensor of sizes {C}.
    const std::vector<ctype> weight_data;
    // Affine transform bias; used to build a tensor of sizes {C}.
    const std::vector<ctype> bias_data;
    // Batch size N
    const int64_t N;
    // Number of channels C
    const int64_t C;
    // Spatial size HxW
    const int64_t HxW;
    // Number of groups
    const int64_t group;
    // A value added to the denominator for numerical stability. Stored as
    // `ctype` and converted to the operator's `double eps` parameter at the
    // call site.
    const ctype eps;
    // The expected output data; must agree with `sizes`.
    const std::vector<ctype> expected_data;
    // The expected mean data; used to build a tensor of sizes {N, group}.
    const std::vector<ctype> expected_mean_data;
    // The expected rstd data; used to build a tensor of sizes {N, group}.
    const std::vector<ctype> expected_rstd_data;
  };

  /// Sizes of the weight and bias tensors for `test_case`: {C}.
  template <ScalarType DTYPE>
  static std::vector<int32_t> channel_sizes(
      const NativeGroupNormTestCase<DTYPE>& test_case) {
    return {static_cast<int32_t>(test_case.C)};
  }

  /// Sizes of the mean and rstd tensors for `test_case`: {N, group}.
  template <ScalarType DTYPE>
  static std::vector<int32_t> stats_sizes(
      const NativeGroupNormTestCase<DTYPE>& test_case) {
    return {
        static_cast<int32_t>(test_case.N),
        static_cast<int32_t>(test_case.group)};
  }

  /// Runs the provided test cases.
  ///
  /// For each case the operator is invoked once; every element of the
  /// returned tuple is checked against the corresponding out-argument, and
  /// all three outputs are compared with the expected tensors. Half and
  /// BFloat16 use a relative tolerance of 1e-2 and their dedicated absolute
  /// tolerance; every other dtype uses the defaults of EXPECT_TENSOR_CLOSE.
  template <ScalarType DTYPE>
  void run_test_cases(
      const std::vector<NativeGroupNormTestCase<DTYPE>>& test_cases) {
    TensorFactory<DTYPE> tf;
    for (const auto& test_case : test_cases) {
      const std::vector<int32_t> affine_sizes = channel_sizes(test_case);
      const std::vector<int32_t> stat_sizes = stats_sizes(test_case);

      Tensor in = tf.make(test_case.sizes, test_case.input_data);
      optional<Tensor> weight = tf.make(affine_sizes, test_case.weight_data);
      optional<Tensor> bias = tf.make(affine_sizes, test_case.bias_data);
      Tensor out0 = tf.zeros(test_case.sizes);
      Tensor out1 = tf.zeros(
          stat_sizes, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
      Tensor out2 = tf.zeros(
          stat_sizes, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

      auto result = op_native_group_norm_out(
          in,
          weight,
          bias,
          test_case.N,
          test_case.C,
          test_case.HxW,
          test_case.group,
          test_case.eps,
          out0,
          out1,
          out2);

      // Every element of the returned tuple must match the out-argument it
      // corresponds to, not just the first one.
      EXPECT_TENSOR_CLOSE(out0, std::get<0>(result));
      EXPECT_TENSOR_CLOSE(out1, std::get<1>(result));
      EXPECT_TENSOR_CLOSE(out2, std::get<2>(result));

      Tensor expected = tf.make(test_case.sizes, test_case.expected_data);
      Tensor expected_mean = tf.make(stat_sizes, test_case.expected_mean_data);
      Tensor expected_rstd = tf.make(stat_sizes, test_case.expected_rstd_data);

      constexpr bool kReducedPrecision =
          DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16;
      if constexpr (kReducedPrecision) {
        // Same relative tolerance for both reduced-precision dtypes; only the
        // absolute tolerance depends on the dtype.
        constexpr double kRtol = 1e-2;
        const double atol = (DTYPE == ScalarType::Half)
            ? executorch::runtime::testing::internal::kDefaultHalfAtol
            : executorch::runtime::testing::internal::kDefaultBFloat16Atol;
        EXPECT_TENSOR_CLOSE_WITH_TOL(out0, expected, kRtol, atol);
        EXPECT_TENSOR_CLOSE_WITH_TOL(out1, expected_mean, kRtol, atol);
        EXPECT_TENSOR_CLOSE_WITH_TOL(out2, expected_rstd, kRtol, atol);
      } else {
        EXPECT_TENSOR_CLOSE(out0, expected);
        EXPECT_TENSOR_CLOSE(out1, expected_mean);
        EXPECT_TENSOR_CLOSE(out2, expected_rstd);
      }
    }
  }

  /// Builds the floating-point test table for DTYPE and runs it through
  /// `run_test_cases`.
  ///
  /// The single case uses a {5, 6, 2, 2} input (120 values) with N=5, C=6,
  /// HxW=4 and 3 groups, so the mean/rstd tables hold N * group = 15 values.
  template <ScalarType DTYPE>
  void run_floating_point_test_cases() {
    std::vector<NativeGroupNormTestCase<DTYPE>> test_cases = {
        {
            {5, 6, 2, 2}, // sizes
            {-0.8125, 0.0625,  -2.7500, -3.0625, -1.1250, -2.1250, -1.3125,
             -4.0625, 2.8125,  -2.0625, 4.2500,  3.5000,  -0.3750, 1.6250,
             4.3125,  -1.0625, -2.8750, 3.3750,  4.9375,  4.0625,  -3.0625,
             -1.8750, -2.7500, -2.5625, -0.1875, -3.0000, -2.7500, 0.6875,
             -3.2500, -3.1875, 1.0000,  -4.6250, -0.1875, -1.7500, 4.5000,
             -1.8750, -2.6875, 4.8125,  -3.8125, -2.9375, -1.1875, 2.8750,
             0.7500,  2.8750,  1.1250,  -0.6250, -2.2500, -3.7500, 3.2500,
             -0.3750, -2.0625, -4.7500, 2.0625,  3.0000,  -3.1875, -4.1250,
             -3.7500, 1.2500,  -2.3125, 1.5625,  3.1250,  0.3125,  3.2500,
             -2.7500, -3.8125, -4.2500, -4.3125, -0.5625, -0.4375, 2.9375,
             -1.3750, -0.6250, -2.5625, -4.5625, 0.1250,  -3.5000, -5.0000,
             -1.0000, -4.6875, -0.6875, 1.1250,  1.8750,  -4.5000, 4.3125,
             4.5625,  0.2500,  -3.6250, 4.5625,  -3.5000, -2.1250, -3.6250,
             -2.9375, 3.6875,  3.9375,  4.3750,  3.0625,  2.4375,  2.0625,
             -2.4375, -3.9375, 3.6875,  2.7500,  -0.8750, -0.9375, 2.7500,
             -2.4375, -2.3750, -0.9375, -4.8750, 0.1875,  3.5000,  -2.0000,
             -0.2500, -2.7500, 0.3125,  1.2500,  -0.5625, 0.0000,  1.8125,
             1.0625}, // input_data
            {4.5625, -2.8750, -0.6875, 0.5625, -2.0625, -2.7500}, // weight_data
            {-0.5000, -2.7500, 1.1875, 3.6875, 3.8125, 4.6875}, // bias_data
            5, // N
            6, // C
            4, // HxW
            3, // group
            1e-5, // eps
            {3.419882,  6.578348,  -3.573864, -4.701888,
             -4.509254, -2.234663, -4.082768, 2.172355,
             0.838826,  2.270225,  0.416747,  0.636962,
             3.207030,  3.687500,  4.333131,  3.041869,
             5.547079,  1.649148,  0.674665,  1.220376,
             7.156189,  6.168714,  6.896327,  6.740410,
             3.509863,  -3.022041, -2.441427, 5.542011,
             -0.794903, -0.886369, -7.014627, 1.217361,
             1.120617,  1.463606,  0.091652,  1.491045,
             3.293219,  4.640229,  3.091168,  3.248319,
             4.895990,  1.114683,  3.092597,  1.114683,
             3.262238,  5.434066,  7.450763,  9.312329,
             5.570122,  0.101119,  -2.444796, -6.499403,
             -5.446074, -6.337338, -0.454995, 0.436269,
             2.228491,  0.871598,  1.838385,  0.786793,
             4.362284,  3.737805,  4.390039,  3.057817,
             5.814659,  6.202621,  6.258044,  2.932658,
             3.366583,  -0.623879, 4.475045,  3.588276,
             -0.082914, -4.936279, 6.438795,  -2.357929,
             0.714463,  -5.402106, 0.236606,  -5.879963,
             1.176247,  1.021916,  2.333727,  0.520341,
             4.275447,  3.549392,  2.896994,  4.275447,
             6.120910,  5.298480,  6.195676,  5.784461,
             2.033296,  1.833920,  1.485010,  2.531738,
             3.193988,  2.532378,  -5.406940, -8.053379,
             -6.467402, -5.425139, -1.395059, -1.325575,
             0.266062,  1.622680,  1.606336,  1.230405,
             2.809896,  3.893110,  4.601880,  3.425055,
             4.374411,  8.283354,  3.494898,  2.029045,
             6.088204,  4.915522,  1.136877,  2.700454}, // expected_data
            {-1.89843750,
             1.62500000,
             -0.09375000,
             -1.91406250,
             -0.49218744,
             -0.02343750,
             -0.77343756,
             0.08593753,
             -1.55468738,
             -2.73437500,
             1.07031238,
             0.35937503,
             0.34374997,
             -0.77343750,
             0.10937499}, // expected_mean_data
            {0.79116172,
             0.42708409,
             0.30238494,
             0.50903118,
             0.31929117,
             0.45128885,
             0.33067191,
             0.39473253,
             0.42994878,
             0.53187561,
             0.29930803,
             0.29000264,
             0.38669431,
             0.38038814,
             0.75809801}, // expected_rstd_data
        },
    };

    run_test_cases(test_cases);
  }

  /// Runs death test cases.
  ///
  /// Each case is expected to be reported as a kernel failure by
  /// ET_EXPECT_KERNEL_FAILURE. An empty `weight_data` or `bias_data` leaves
  /// the corresponding optional tensor unset. The `expected_*` fields of the
  /// test case are not read here.
  template <ScalarType DTYPE>
  void run_death_test_cases(
      const std::vector<NativeGroupNormTestCase<DTYPE>>& test_cases) {
    TensorFactory<DTYPE> tf;
    for (const auto& test_case : test_cases) {
      const std::vector<int32_t> affine_sizes = channel_sizes(test_case);
      const std::vector<int32_t> stat_sizes = stats_sizes(test_case);

      Tensor in = tf.make(test_case.sizes, test_case.input_data);
      optional<Tensor> weight;
      optional<Tensor> bias;
      if (!test_case.weight_data.empty()) {
        weight = tf.make(affine_sizes, test_case.weight_data);
      }
      if (!test_case.bias_data.empty()) {
        bias = tf.make(affine_sizes, test_case.bias_data);
      }
      Tensor out0 = tf.zeros(test_case.sizes);
      Tensor out1 = tf.zeros(stat_sizes);
      Tensor out2 = tf.zeros(stat_sizes);

      ET_EXPECT_KERNEL_FAILURE(
          context_,
          op_native_group_norm_out(
              in,
              weight,
              bias,
              test_case.N,
              test_case.C,
              test_case.HxW,
              test_case.group,
              test_case.eps,
              out0,
              out1,
              out2));
    }
  }

  /// Test cases with incompatible types.
  ///
  /// Builds a single integer-valued case and hands it to
  /// `run_death_test_cases`, so the expected_* entries only fill the struct.
  template <ScalarType DTYPE>
  void run_int_test_cases() {
    std::vector<NativeGroupNormTestCase<DTYPE>> test_cases = {
        {
            // Cannot be represented by a type other than float.
            {2, 4, 2, 2}, // sizes
            {1, 0,  -1, -1, 4, 0, 2,  -2, 1, 0,  -1, -1, 4, 0, 2, -2, 1,
             0, -1, -1, 4,  0, 2, -2, 1,  0, -1, -1, 4,  0, 2, -2}, // input_data
            {1, 1, 1, 1}, // weights
            {0, 0, 0, 0}, // bias
            2, // N
            4, // C
            4, // HxW
            2, // group
            1, // eps
            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // expected_data
            {0, 0, 0, 0}, // expected_mean_data
            {1, 1, 1, 1}, // expected_rstd_data
        },
    };
    run_death_test_cases(test_cases);
  }
};

/// Each test below runs the shared test cases with tensors of one DTYPE.
TEST_F(OpNativeGroupNormTest, DoubleTensors) {
  // Run the floating-point test table with Double tensors.
  constexpr ScalarType kDtype = ScalarType::Double;
  this->run_floating_point_test_cases<kDtype>();
}

TEST_F(OpNativeGroupNormTest, FloatTensors) {
  // Run the floating-point test table with Float tensors.
  constexpr ScalarType kDtype = ScalarType::Float;
  this->run_floating_point_test_cases<kDtype>();
}

TEST_F(OpNativeGroupNormTest, HalfTensors) {
  // Run the floating-point test table with Half tensors.
  constexpr ScalarType kDtype = ScalarType::Half;
  this->run_floating_point_test_cases<kDtype>();
}

TEST_F(OpNativeGroupNormTest, BFloat16Tensors) {
  // Run the floating-point test table with BFloat16 tensors.
  constexpr ScalarType kDtype = ScalarType::BFloat16;
  this->run_floating_point_test_cases<kDtype>();
}

TEST_F(OpNativeGroupNormTest, IntTensorsDies) {
  // Int tensors go through the death-test path: the data cannot be
  // represented by a type other than float.
  constexpr ScalarType kDtype = ScalarType::Int;
  this->run_int_test_cases<kDtype>();
}
